"""
Phase 4c: Probing the High-Income vs Low-Income Disparity
============================================================
Phase 4 found Z polynomials significant for low-income but not high-income
countries — opposite of the canonical Japanification story.

Three competing stories:
  D. Restricted variation (rich countries all aging similarly)
  E. Policy buffers (institutions absorb the demographic shock)
  F. Different mechanism (demographic dividend ending ≠ aging per se)

Tests:
  6.  Variation diagnostic (within-group Z spread)
  7.  Institutional interaction (KAOPEN as proxy for institutional quality)
  8.  Working-age share decomposition
  9.  Matched sample (same OADR range across income groups)
  10. Continuous income interaction (at what income does the effect fade?)

Input:  japanification/data/processed/japan_panel_indexed.csv
Output: japanification/output/tables/phase4c_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Hard-coded absolute project root; every other path below is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
# Directory the input panel CSV is read from.
PROCESSED_DIR = JAPAN_DIR / "data" / "processed"
# Directory the phase4c_*.csv result tables are written to.
TABLE_DIR = JAPAN_DIR / "output" / "tables"

# The estimator lives in a sibling sub-project that is not an installed
# package, so its source directory is put at the front of sys.path first.
# The import must stay below the insert.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate a PanelGLS model, print its summary, and return the results.

    Args:
        y: Dependent-variable values passed straight to ``PanelGLS.fit``.
        X: Regressor matrix passed straight to ``PanelGLS.fit``.
        entity_ids: Panel entity identifier for each observation.
        time_ids: Time identifier for each observation.
        feature_names: Names forwarded to ``summary`` and ``to_dataframe``.
        label: Heading printed above the summary and stored in the ``model``
            column of the returned table.

    Returns:
        Tuple ``(fitted_model, coefficient_table)``.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Heading first, then the estimator's own printed summary
    print(f"\n  {label}")
    estimator.summary(feature_names=feature_names)

    # Tag the coefficient table so results from many fits can be stacked later
    coef_table = estimator.to_dataframe(feature_names=feature_names)
    coef_table['model'] = label
    return estimator, coef_table


def _significance_stars(p_value):
    """Return the star marker for a p-value at the 1% / 5% / 10% levels."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


def main():
    """Run the Phase 4c income-disparity probes and write their result tables.

    Reads ``japan_panel_indexed.csv`` from ``PROCESSED_DIR``, splits
    observations at the pooled median of ``gdp_pc_ppp``, then runs Tests 6-10
    and a pre/post-GFC split, printing each model summary along the way.

    Files written to ``TABLE_DIR`` (created if missing):
        * ``phase4c_variation_diagnostic.csv``
        * ``phase4c_income_quartiles.csv`` (only if a quartile model was fit)
        * ``phase4c_income_disparity_results.csv``
        * ``phase4c_z1_summary.csv``

    The last two are skipped, with a message, when no model met its minimum
    sample size.
    """
    print("=" * 70)
    print("PHASE 4c: Probing the Income Disparity")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Make sure the output directory exists before the first table is written
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    base_vars = demo_vars + controls
    dep_var = 'japan_index_2c'

    all_results = []

    # Compute income groupings (split at the pooled country-year median)
    gdp_pc = df['gdp_pc_ppp']
    df['log_gdp_pc'] = np.log(gdp_pc.clip(lower=100))
    median_gdp = gdp_pc.median()
    # Keep the group undefined where GDP is missing: a bare comparison evaluates
    # NaN >= median as False and would silently file those rows as low-income.
    df['high_income'] = (gdp_pc >= median_gdp).astype(float).where(gdp_pc.notna())

    # =================================================================
    # TEST 6: Variation diagnostic
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 6: Variation Diagnostic")
    print("=" * 70)

    print("\n  Within-group variation of demographic variables:")
    print(f"  {'Variable':<20} {'High-Inc SD':>12} {'Low-Inc SD':>12} {'Ratio':>8} {'HI range':>18} {'LI range':>18}")
    print("  " + "-" * 90)

    var_diag_rows = []
    for v in demo_vars + ['old_dep', 'working_age_share', 'life_expectancy']:
        if v not in df.columns:
            continue
        hi = df[df['high_income'] == 1][v].dropna()
        li = df[df['high_income'] == 0][v].dropna()
        if len(hi) == 0 or len(li) == 0:
            continue

        hi_sd = hi.std()
        li_sd = li.std()
        ratio = hi_sd / li_sd if li_sd > 0 else np.nan
        hi_range = f"[{hi.min():.2f}, {hi.max():.2f}]"
        li_range = f"[{li.min():.2f}, {li.max():.2f}]"

        print(f"  {v:<20} {hi_sd:>12.3f} {li_sd:>12.3f} {ratio:>8.2f} {hi_range:>18} {li_range:>18}")

        var_diag_rows.append({
            'variable': v,
            'high_income_sd': hi_sd,
            'low_income_sd': li_sd,
            'ratio': ratio,
            'high_income_n': len(hi),
            'low_income_n': len(li),
            'high_income_mean': hi.mean(),
            'low_income_mean': li.mean(),
        })

    var_diag = pd.DataFrame(var_diag_rows)
    var_diag.to_csv(TABLE_DIR / "phase4c_variation_diagnostic.csv", index=False)

    # Between-country variation: SD of per-country means (averaged over years).
    # A country whose income group changes over time contributes one mean to
    # each group it appears in.
    print("\n  Between-country variation (country means):")
    country_means = df.groupby(['iso3', 'high_income'])[demo_vars].mean().reset_index()
    for v in demo_vars:
        hi_sd = country_means[country_means['high_income'] == 1][v].std()
        li_sd = country_means[country_means['high_income'] == 0][v].std()
        ratio = hi_sd / li_sd if li_sd > 0 else np.nan
        print(f"    {v}: HI between-country SD={hi_sd:.3f}, LI={li_sd:.3f}, ratio={ratio:.2f}")

    # =================================================================
    # TEST 7: Institutional interaction
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 7: Institutional Interaction")
    print("=" * 70)

    # KAOPEN is used here as a stand-in for institutional quality; the model
    # is re-estimated separately within each KAOPEN tercile.
    kaopen_terciles = df['kaopen'].quantile([0.33, 0.67])
    df['kaopen_group'] = 'low'
    df.loc[df['kaopen'] > kaopen_terciles[0.33], 'kaopen_group'] = 'mid'
    df.loc[df['kaopen'] > kaopen_terciles[0.67], 'kaopen_group'] = 'high'
    # Rows with missing KAOPEN stay labelled 'low' but are removed by the
    # dropna below because 'kaopen' is one of the base regressors.

    for group in ['low', 'mid', 'high']:
        sub = df[df['kaopen_group'] == group].dropna(subset=[dep_var] + base_vars)
        if len(sub) >= 100:
            _, r = fit_and_report(
                sub[dep_var].values, sub[base_vars].values,
                sub['iso3'].values, sub['year'].values,
                base_vars, f"KAOPEN tercile: {group}"
            )
            all_results.append(r)

    # NFA position as a second institutional proxy: interact each Z term with
    # a dummy for a positive lagged net foreign asset position.
    df['nfa_positive_dummy'] = (df['nfa_gdp_lag'] > 0).astype(int) if 'nfa_gdp_lag' in df.columns else 0
    for zv in demo_vars:
        df[f'{zv}_x_nfa_pos'] = df[zv] * df['nfa_positive_dummy']

    nfa_int_vars = [f'{zv}_x_nfa_pos' for zv in demo_vars]
    nfa_vars = base_vars + ['nfa_positive_dummy'] + nfa_int_vars
    est_nfa = df.dropna(subset=[dep_var] + nfa_vars)
    if len(est_nfa) >= 200:
        _, r_nfa = fit_and_report(
            est_nfa[dep_var].values, est_nfa[nfa_vars].values,
            est_nfa['iso3'].values, est_nfa['year'].values,
            nfa_vars, "NFA creditor interaction (institutional proxy)"
        )
        all_results.append(r_nfa)

    # =================================================================
    # TEST 8: Working-age share decomposition
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 8: Working-Age Share Decomposition")
    print("=" * 70)

    # Does the Z effect in low-income countries survive controlling for
    # working_age_share? If not, the effect is really about the
    # demographic dividend ending, not aging per se.

    if 'working_age_share' in df.columns:
        # Low-income: Z alone
        li_df = df[df['high_income'] == 0].dropna(subset=[dep_var] + base_vars)
        m_li_base, r_li_base = fit_and_report(
            li_df[dep_var].values, li_df[base_vars].values,
            li_df['iso3'].values, li_df['year'].values,
            base_vars, "Low-income: Z polynomials only"
        )
        all_results.append(r_li_base)

        # Low-income: Z + working_age_share
        was_vars = base_vars + ['working_age_share']
        li_was = df[df['high_income'] == 0].dropna(subset=[dep_var] + was_vars)
        if len(li_was) >= 100:
            m_li_was, r_li_was = fit_and_report(
                li_was[dep_var].values, li_was[was_vars].values,
                li_was['iso3'].values, li_was['year'].values,
                was_vars, "Low-income: Z + working_age_share"
            )
            all_results.append(r_li_was)

            # Compare the Z_1 coefficient across the two specifications.
            # NOTE(review): index 0 assumes beta/pvalues follow the order of
            # the regressor columns with no leading intercept — confirm
            # against PanelGLS.
            z1_base = m_li_base.beta[0]
            z1_was = m_li_was.beta[0]
            p_base = m_li_base.pvalues[0]
            p_was = m_li_was.pvalues[0]
            print("\n  DIAGNOSIS: Does working_age_share absorb the Z effect?")
            print(f"    Z₁ without WAS: {z1_base:.3f} (p={p_base:.3f})")
            print(f"    Z₁ with WAS:    {z1_was:.3f} (p={p_was:.3f})")
            # Signed attenuation: share of the baseline coefficient that
            # disappears once WAS is added. A coefficient that grows gives a
            # negative value and a sign flip gives more than 100%, so growth
            # is no longer mistaken for absorption.
            reduction = (1 - z1_was / z1_base) * 100 if z1_base != 0 else 0
            print(f"    Reduction: {reduction:.0f}%")
            if reduction > 50:
                print("    → Story F supported: Z effect proxies for demographic dividend ending")
            else:
                print("    → Story F not supported: Z effect survives controlling for WAS")

        # High-income: same exercise
        hi_df = df[df['high_income'] == 1].dropna(subset=[dep_var] + base_vars)
        _, r_hi_base = fit_and_report(
            hi_df[dep_var].values, hi_df[base_vars].values,
            hi_df['iso3'].values, hi_df['year'].values,
            base_vars, "High-income: Z polynomials only"
        )
        all_results.append(r_hi_base)

        hi_was = df[df['high_income'] == 1].dropna(subset=[dep_var] + was_vars)
        if len(hi_was) >= 100:
            _, r_hi_was = fit_and_report(
                hi_was[dep_var].values, hi_was[was_vars].values,
                hi_was['iso3'].values, hi_was['year'].values,
                was_vars, "High-income: Z + working_age_share"
            )
            all_results.append(r_hi_was)

    # =================================================================
    # TEST 9: Matched sample (same OADR range)
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 9: Matched Sample (Common OADR Range)")
    print("=" * 70)

    # Overlap of the 5th-95th percentile OADR bands of the two income groups
    hi_oadr = df[df['high_income'] == 1]['old_dep'].dropna()
    li_oadr = df[df['high_income'] == 0]['old_dep'].dropna()
    overlap_lo = max(hi_oadr.quantile(0.05), li_oadr.quantile(0.05))
    overlap_hi = min(hi_oadr.quantile(0.95), li_oadr.quantile(0.95))

    print(f"  OADR overlap range: [{overlap_lo:.3f}, {overlap_hi:.3f}]")
    print(f"  High-income full range: [{hi_oadr.min():.3f}, {hi_oadr.max():.3f}]")
    print(f"  Low-income full range:  [{li_oadr.min():.3f}, {li_oadr.max():.3f}]")

    matched = df[(df['old_dep'] >= overlap_lo) & (df['old_dep'] <= overlap_hi)]
    print(f"  Matched sample: {len(matched):,} obs, {matched['iso3'].nunique()} countries")

    for inc_label, inc_val in [('Matched high-income', 1), ('Matched low-income', 0)]:
        sub = matched[matched['high_income'] == inc_val].dropna(subset=[dep_var] + base_vars)
        if len(sub) >= 100:
            _, r_match = fit_and_report(
                sub[dep_var].values, sub[base_vars].values,
                sub['iso3'].values, sub['year'].values,
                base_vars, inc_label
            )
            all_results.append(r_match)

    # Matched pooled with income dummy and interaction.
    # Explicit copy: columns are added below, and assigning into a frame
    # derived from a filtered slice can trigger SettingWithCopyWarning.
    matched_est = matched.dropna(subset=[dep_var] + base_vars + ['high_income']).copy()
    if len(matched_est) >= 200:
        for zv in demo_vars:
            matched_est[f'{zv}_x_hi'] = matched_est[zv] * matched_est['high_income']

        hi_int = [f'{zv}_x_hi' for zv in demo_vars]
        matched_vars = base_vars + ['high_income'] + hi_int
        m_pool, r_pool = fit_and_report(
            matched_est[dep_var].values, matched_est[matched_vars].values,
            matched_est['iso3'].values, matched_est['year'].values,
            matched_vars, "Matched sample: pooled with income interaction"
        )
        all_results.append(r_pool)

        # Report the differential (the base Z_1 term is the low-income slope)
        print(f"\n  Z₁ for low-income (base): {m_pool.beta[0]:.3f}")
        z1_hi_idx = matched_vars.index('Z_1_x_hi')
        print(f"  Z₁ differential for high-income: {m_pool.beta[z1_hi_idx]:.3f} "
              f"(p={m_pool.pvalues[z1_hi_idx]:.3f})")
        print(f"  Z₁ for high-income (total): {m_pool.beta[0] + m_pool.beta[z1_hi_idx]:.3f}")

    # =================================================================
    # TEST 10: Continuous income interaction
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 10: Continuous Income Interaction")
    print("=" * 70)

    # Test Z₁ alone × log GDP/cap (Phase 4 tested all three simultaneously)
    if 'log_gdp_pc' in df.columns:
        # Z₁ × income only
        df['Z_1_x_lgdppc'] = df['Z_1'] * df['log_gdp_pc']
        z1_inc_vars = ['Z_1', 'Z_2', 'Z_3', 'log_gdp_pc', 'Z_1_x_lgdppc'] + controls
        est_z1i = df.dropna(subset=[dep_var] + z1_inc_vars)

        if len(est_z1i) >= 200:
            _, r_z1i = fit_and_report(
                est_z1i[dep_var].values, est_z1i[z1_inc_vars].values,
                est_z1i['iso3'].values, est_z1i['year'].values,
                z1_inc_vars, "Z₁ × log(GDP/cap) only"
            )
            all_results.append(r_z1i)

        # Income quartile regressions for finer granularity
        print("\n  Income Quartile Regressions:")
        df['income_quartile'] = pd.qcut(df['gdp_pc_ppp'], 4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
        quartile_z1 = []
        for q in ['Q1', 'Q2', 'Q3', 'Q4']:
            sub = df[df['income_quartile'] == q].dropna(subset=[dep_var] + base_vars)
            if len(sub) >= 100:
                m_q, r_q = fit_and_report(
                    sub[dep_var].values, sub[base_vars].values,
                    sub['iso3'].values, sub['year'].values,
                    base_vars, f"Income quartile {q}"
                )
                all_results.append(r_q)

                quartile_z1.append({
                    'quartile': q,
                    'median_gdp_pc': sub['gdp_pc_ppp'].median(),
                    'Z_1_coef': m_q.beta[0],
                    'Z_1_pval': m_q.pvalues[0],
                    'N': m_q.n_obs,
                    'N_countries': m_q.n_countries,
                    'R_squared': m_q.r_squared,
                })

        if quartile_z1:
            q_df = pd.DataFrame(quartile_z1)
            q_df.to_csv(TABLE_DIR / "phase4c_income_quartiles.csv", index=False)
            print("\n  Z₁ by income quartile:")
            print(f"  {'Q':>4} {'Median GDP/cap':>15} {'Z₁ coef':>10} {'p-val':>8} {'N':>6} {'R²':>6}")
            print("  " + "-" * 55)
            for _, row in q_df.iterrows():
                sig = _significance_stars(row['Z_1_pval'])
                print(f"  {row['quartile']:>4} {row['median_gdp_pc']:>15,.0f} {row['Z_1_coef']:>10.3f} "
                      f"{row['Z_1_pval']:>7.3f}{sig} {row['N']:>5.0f} {row['R_squared']:>6.3f}")

            # Find the crossover point
            pos_qs = q_df[q_df['Z_1_coef'] > 0]
            neg_qs = q_df[q_df['Z_1_coef'] < 0]
            if len(pos_qs) > 0 and len(neg_qs) > 0:
                highest_pos = pos_qs['median_gdp_pc'].max()
                lowest_neg = neg_qs['median_gdp_pc'].min()
                # The interpretation below only holds when every positive
                # quartile is poorer than every negative one; otherwise the
                # midpoint is not a threshold between the two regimes.
                if highest_pos < lowest_neg:
                    crossover = (highest_pos + lowest_neg) / 2
                    print(f"\n  Approximate crossover: ~${crossover:,.0f} GDP/cap (PPP)")
                    print("  Below this: aging → more Japanification")
                    print("  Above this: aging → less Japanification (institutions resist)")
                else:
                    print("\n  No clean crossover: Z₁ sign does not switch once "
                          "from positive to negative as income rises")

    # =================================================================
    # ADDITIONAL: Pre-GFC income split (is it a period effect?)
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  ADDITIONAL: Income Split × Period Interaction")
    print("=" * 70)

    # 2008 falls in neither window (<= 2007 vs >= 2009)
    for period_label, period_mask in [('Pre-GFC', df['year'] <= 2007),
                                       ('Post-GFC', df['year'] >= 2009)]:
        for inc_label, inc_mask in [('High-income', df['high_income'] == 1),
                                     ('Low-income', df['high_income'] == 0)]:
            sub = df[period_mask & inc_mask].dropna(subset=[dep_var] + base_vars)
            if len(sub) >= 80:
                _, r_pi = fit_and_report(
                    sub[dep_var].values, sub[base_vars].values,
                    sub['iso3'].values, sub['year'].values,
                    base_vars, f"{inc_label} × {period_label}"
                )
                all_results.append(r_pi)

    # =================================================================
    # SYNTHESIS
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  SYNTHESIS: What Explains the Income Disparity?")
    print("=" * 70)

    # pd.concat raises ValueError on an empty list, so stop cleanly instead
    if not all_results:
        print("\n  No model met its minimum sample size; nothing to synthesise.")
        return

    results_df = pd.concat(all_results, ignore_index=True)

    # Z_1 summary
    z1_summary = (results_df[results_df['variable'] == 'Z_1']
                  [['model', 'coefficient', 'std_error', 'p_value']]
                  .sort_values('model'))
    print("\n  Z₁ coefficient across all income disparity tests:")
    for _, row in z1_summary.iterrows():
        sig = _significance_stars(row['p_value'])
        print(f"    {row['model']:<50} {row['coefficient']:>8.3f} (p={row['p_value']:.3f}){sig}")

    # Save
    results_df.to_csv(TABLE_DIR / "phase4c_income_disparity_results.csv", index=False)
    z1_summary.to_csv(TABLE_DIR / "phase4c_z1_summary.csv", index=False)
    print(f"\nSaved to {TABLE_DIR / 'phase4c_*.csv'}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
